import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer

from ..builder import BACKBONES
from ..utils import ResLayer
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNetV1d

class RSoftmax(nn.Module):
    """Radix Softmax module in ''SplitAttentionConv2d'',

    Args:
        radix(int):Radix of input.
        groups(int):Groups of input.
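
    Example:
        >>> # An illustrative sketch of the shape flow; the sizes below
        >>> # assume radix=2 and groups=1.
        >>> import torch
        >>> m = RSoftmax(radix=2, groups=1)
        >>> x = torch.rand(4, 16, 1, 1)
        >>> tuple(m(x).shape)
        (4, 16)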
    """

    def __init__(self, radix, groups):
        super().__init__()
        self.radix = radix
        self.groups = groups

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            x = torch.sigmoid(x)
        return x

class SplitAttentionConv2d(nn.Module):
    """
    Split-Attention Conv2d in ResNeSt.

    Args:
        in_channels(int):Number of channels in the input feature map.
        channels(int):Number of intermediate channels.
        kernel_size(int|tuple[int]):Size of the convolution kernel.
        stride(int|tuple[int]):Stride of the convolution.
        padding(int|tuple[int]):Zero-padding added to bo sides of
        dilation(int | tuple[int]):Spacing between kernel elements.
        groups(int):Number of blocked connections from input channels to output channels.
        grps(int):Same as nn.Conv2d.
        radix(int):Radix of SpltAtConv2d.Default:2
        reduction factor(int):Reduction factor of inter_channels.Default:4.
        conv_cfg(dict):Config dict for convolution layer.Default:None,which means using conv2d.
        norm_cfg(dict):Config dict for normalization layer.Default:None.
        dcn(dict):Config dict for DCN.Default:None.
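
    Example:
        >>> # An illustrative sketch; shapes assume the defaults
        >>> # (radix=2, groups=1, reduction_factor=4).
        >>> import torch
        >>> conv = SplitAttentionConv2d(64, 64, kernel_size=3, padding=1)
        >>> out = conv(torch.rand(2, 64, 56, 56))
        >>> tuple(out.shape)
        (2, 64, 56, 56)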
    """

    def __init__(self,
                 in_channels,
                 channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 radix=2,
                 reduction_factor=4,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None):
        super(SplitAttentionConv2d, self).__init__()
        inter_channels = max(in_channels * radix // reduction_factor, 32)
        self.radix = radix
        self.groups = groups
        self.channels = channels
        self.with_dcn = dcn is not None
        self.dcn = dcn
        fallback_on_stride = False
        if self.with_dcn:
            fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
        if self.with_dcn and not fallback_on_stride:
            assert conv_cfg is None, 'conv_cfg must be None for DCN'
            conv_cfg = dcn
        self.conv = build_conv_layer(
            conv_cfg,
            in_channels,
            channels * radix,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups * radix,
            bias=False)
        # To be consistent with the original implementation, the norm layers
        # are numbered starting from 0.
        self.norm0_name, norm0 = build_norm_layer(
            norm_cfg, channels * radix, postfix=0)
        self.add_module(self.norm0_name, norm0)
        self.relu = nn.ReLU(inplace=True)
        self.fc1 = build_conv_layer(
            None, channels, inter_channels, 1, groups=self.groups)
        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, inter_channels, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.fc2 = build_conv_layer(
            None, inter_channels, channels * radix, 1, groups=self.groups)
        self.rsoftmax = RSoftmax(radix, groups)

    @property
    def norm0(self):
        """nn.Module:the normalization layer named "norm0" """
        return getattr(self,self.norm0_name)

    @property
    def norm1(self):
        """nn.Module:the normalization layer named "norm1" """
        return getattr(self,self.norm1_name)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm0(x)
        x = self.relu(x)

        batch = x.size(0)
        if self.radix > 1:
            # The * unpacks the remaining spatial dims (H, W) into view().
            splits = x.view(batch, self.radix, -1, *x.shape[2:])
            gap = splits.sum(dim=1)
        else:
            gap = x
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)

        gap = self.norm1(gap)
        gap = self.relu(gap)

        atten = self.fc2(gap)
        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

        if self.radix > 1:
            attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
            out = torch.sum(attens * splits, dim=1)
        else:
            out = atten * x
        return out.contiguous()

class Bottleneck(_Bottleneck):
    """Bottleneck block for ResNeSt.

    Args:
        inplane(int):Input planes of this block.
        planes(int):Middle planes of this block.
        groups(int):Groups of conv2.
        base_width(int):Base of width in terms of base channels.Default:4.
        base_channels(int):Base of channels for calculating width.
            Default:64.
        radix(int):Radix of SpltAtConv2d.Default:2
        reduction factor(int):Reduction factor of inter_channels in SplitAttentionConv2d.Default:4.
        avg_down_stride(bool):Whether to use average pool for stride in Bottleneck.Default:True.
        kwargs(dict):Key word arguments for base class.
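
    Example:
        >>> # A minimal sketch; it relies on the defaults inherited from the
        >>> # base ResNet Bottleneck (stride=1, no downsample), so the
        >>> # residual shapes already match.
        >>> import torch
        >>> block = Bottleneck(256, 64, radix=2)
        >>> out = block(torch.rand(2, 256, 32, 32))
        >>> tuple(out.shape)
        (2, 256, 32, 32)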

    """

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 groups=1,
                 base_width=4,
                 base_channels=64,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        """Bottleneck block for ResNeSt."""
        super(Bottleneck, self).__init__(inplanes, planes, **kwargs)

        if groups == 1:
            width = self.planes
        else:
            width = math.floor(self.planes *
                               (base_width / base_channels)) * groups

        self.avg_down_stride = avg_down_stride and self.conv2_stride > 1

        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, width, postfix=1)

        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, self.planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.inplanes,
            width,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.with_modulated_dcn = False
        self.conv2 = SplitAttentionConv2d(
            width,
            width,
            kernel_size=3,
            stride=1 if self.avg_down_stride else self.conv2_stride,
            padding=self.dilation,
            dilation=self.dilation,
            groups=groups,
            radix=radix,
            reduction_factor=reduction_factor,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            dcn=self.dcn)
        # The norm after conv2 is built inside SplitAttentionConv2d, so the
        # norm2 inherited from the base Bottleneck is removed.
        delattr(self, self.norm2_name)

        if self.avg_down_stride:
            # Move the stride out of conv2 into a 3x3 average pooling layer.
            self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)

        self.conv3 = build_conv_layer(
            self.conv_cfg,
            width,
            self.planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

    def forward(self, x):

        def _inner_forward(x):
            identity = x
            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv1_plugin_names)

            out = self.conv2(out)

            if self.avg_down_stride:
                out = self.avd_layer(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv2_plugin_names)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv3_plugin_names)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out

@BACKBONES.register_module()
class ResNeSt(ResNetV1d):
    """ResNeSt backbone.

    Args:
        groups (int): Number of groups of Bottleneck. Default: 1.
        base_width (int): Base width of Bottleneck. Default: 4.
        radix (int): Radix of SplitAttentionConv2d. Default: 2.
        reduction_factor (int): Reduction factor of inter_channels in
            SplitAttentionConv2d. Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        kwargs (dict): Keyword arguments for ResNet.
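
    Example:
        >>> # A minimal sketch; keyword arguments (e.g. ``depth``) are the
        >>> # ones accepted by the ResNet base class. Eval mode is used so
        >>> # BatchNorm accepts a single-image batch.
        >>> import torch
        >>> model = ResNeSt(depth=50, out_indices=(0, 1, 2, 3)).eval()
        >>> outs = model(torch.rand(1, 3, 32, 32))
        >>> [tuple(o.shape) for o in outs]
        [(1, 256, 8, 8), (1, 512, 4, 4), (1, 1024, 2, 2), (1, 2048, 1, 1)]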
    """

    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3)),
        200: (Bottleneck, (3, 24, 36, 3))
    }

    def __init__(self,
                 groups=1,
                 base_width=4,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        self.groups = groups
        self.base_width = base_width
        self.radix = radix
        self.reduction_factor = reduction_factor
        self.avg_down_stride = avg_down_stride
        super(ResNeSt, self).__init__(**kwargs)

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``ResLayer``."""
        return ResLayer(
            groups=self.groups,
            base_width=self.base_width,
            base_channels=self.base_channels,
            radix=self.radix,
            reduction_factor=self.reduction_factor,
            avg_down_stride=self.avg_down_stride,
            **kwargs)
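

# Usage sketch (an illustration, assuming the standard mmdet config system):
# because the class is registered in BACKBONES, a model config can select it
# with keys mirroring the __init__ arguments above plus the ResNet kwargs,
# e.g.:
#
#   backbone=dict(
#       type='ResNeSt',
#       depth=50,
#       groups=1,
#       base_width=4,
#       radix=2,
#       reduction_factor=4,
#       avg_down_stride=True)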